from copy import deepcopy
import os
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR
from tensorboardX import SummaryWriter

from losses.square_loss import MSELoss

import argparse

# Use float64 as the default floating-point dtype for every tensor created
# below. `set_default_dtype` is the non-deprecated replacement for
# `set_default_tensor_type(torch.DoubleTensor)` (a CPU tensor type, so the
# default device is unaffected either way).
torch.set_default_dtype(torch.float64)


# Command-line configuration for the experiment. Parsed once at import time.
parser = argparse.ArgumentParser(description='Numerical Results')
parser.add_argument('--p', default=512, type=int, help='the dimension of features and prototypes')
parser.add_argument('--num-classes', default=100, type=int, help='the number of classes')
parser.add_argument('--num-per-class', default=10, type=int, help='the number of samples in each class')
parser.add_argument('--lamb', default=None, type=float)
parser.add_argument('--gamma', default=0.0, type=float,
                    help='weight on the summed non-target logits in the margin loss')
parser.add_argument('--lr', default=0.5, type=float, help='learning rate')
parser.add_argument('--adjusted', default=False, action='store_true',
                    help='apply an angle-dependent correction to target logits whose cosine is below -1 + eps')
parser.add_argument('--rescaled', default=False, action='store_true',
                    help="multiply each sample's logits by its feature norm")
parser.add_argument('--seed', default=123, type=int, help='random seed')
parser.add_argument('--exp', default='exp', type=str,
                    help='experiment name, used as a sub-directory of the log directory')

args = parser.parse_args()


# Restrict the run to the first GPU. NOTE(review): this presumably takes
# effect because CUDA has not been initialised yet at this point — confirm
# nothing above touches the GPU.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Reproducibility: request deterministic cuDNN algorithms and disable the
# cuDNN autotuner, which may pick different kernels from run to run.
# Mind the spelling of `deterministic`: assigning a misspelled attribute on
# the module is silently accepted and has no effect.
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed every RNG the script may draw from.
random.seed(args.seed)
np.random.seed(args.seed)
torch.cuda.manual_seed(args.seed)
torch.manual_seed(args.seed)

C = args.num_classes     # number of classes
N = args.num_per_class   # samples per class
CN = C * N               # total number of samples
p = args.p               # feature / prototype dimension


class AveragedSampleMarginLossLoss(nn.Module):
    """Linear margin loss averaged over the batch.

    For a sample with logit row ``z`` and label ``y`` the per-sample loss is::

        -z[y] + alpha * sum_{j != y} z[j]

    and ``forward`` returns the mean of these values. When ``alpha`` is not
    strictly positive the non-target term is dropped entirely, so a negative
    ``alpha`` behaves exactly like ``alpha = 0``.

    Args:
        alpha: weight on the summed non-target logits.
    """

    def __init__(self, alpha=0.1):
        super().__init__()
        self.alpha = alpha

    def forward(self, logits, labels):
        """Return the batch-mean loss as a 0-dim tensor.

        Args:
            logits: scores indexed as ``logits[sample, class]`` (2-D assumed).
            labels: integer class index per sample.
        """
        # Build the mask directly in the logits' dtype/device so the products
        # below do not rely on implicit float32 -> logits.dtype promotion.
        label_one_hot = F.one_hot(labels, logits.size(1)).to(
            dtype=logits.dtype, device=logits.device)
        target = torch.sum(logits * label_one_hot, dim=-1)
        if self.alpha > 0:
            # Only pay for the non-target sum when it contributes to the loss.
            non_target = torch.sum(logits * (1 - label_one_hot), dim=-1)
            loss = -target + self.alpha * non_target
        else:
            loss = -target
        return loss.mean()

def evaluate(out, labels):
    """Return the top-1 accuracy of ``out`` against ``labels`` as a Python float.

    The predicted class of each row is the arg-max along dimension 1.
    An empty ``labels`` raises ZeroDivisionError.
    """
    hits = (out.argmax(dim=1) == labels).sum().item()
    n = labels.size(0)
    return float(hits) / float(n)


def get_margin(weight):
    """Return the smallest pairwise angle, in degrees, between the rows of ``weight``.

    Rows are L2-normalised first, so only directions matter. Cosines are
    clamped slightly inside [-1, 1] to keep ``acos`` finite.
    """
    unit = F.normalize(weight, dim=1)
    k = unit.size(0)
    # Shift the self-similarities (~1) down to ~-1 so they never win the max.
    cos = unit @ unit.t() - 2 * torch.eye(k, device=weight.device)
    cos = cos.clamp(min=-1 + 1e-7, max=1 - 1e-7)
    largest_cos = cos.max()
    # The largest cosine corresponds to the smallest angle.
    return torch.acos(largest_cos).item() / math.pi * 180

def norm_of_weight(weight):
    """Return the sum of squared entries of ``weight`` (the *squared* Frobenius norm) as a Python number."""
    squared = weight.pow(2)
    return squared.sum().item()



# ---------------------------------------------------------------------------
# Synthetic problem: CN free feature vectors H (one per sample) are trained
# against fixed class prototypes W. Labels are 0..C-1 repeated N times, so
# sample i has label i % C.
# ---------------------------------------------------------------------------
labels = torch.arange(C, device=device).repeat(N)
H = torch.randn(CN, p).to(device)
W = torch.randn(C, p).to(device)
# NOTE(review): this subtracts each prototype's mean over the feature
# dimension (every row of W then sums to zero). If the intent was to centre
# the prototypes across classes, the reduction would be over dim=0 — confirm.
W = W - torch.mean(W, dim=-1, keepdim=True)
print(torch.sum(W))

# Only the features are optimised; the prototypes stay fixed.
H.requires_grad = True
W.requires_grad = False

lr = args.lr
# NOTE: the 1/(C-1) fallback is only used when args.gamma is None; the
# --gamma flag defaults to 0.0.
gamma = 1. / (C - 1) if args.gamma is None else args.gamma
# Weight decay on H, taken from --lamb. Omitting the flag (default None)
# means no weight decay. An alternative default once considered here was
# (1 + gamma) / (C * math.sqrt(N)).
lamb = 0.0 if args.lamb is None else args.lamb

# Row i of H_target is W[i % C], i.e. the prototype of sample i's class.
H_target = W.repeat((N, 1))
# W is frozen, so the normalised targets are loop-invariant.
H_target_unit = F.normalize(H_target, dim=-1)

optimizer = torch.optim.SGD(
    [
        {'params': H, 'lr': lr},
    ],
    weight_decay=lamb)

criterion = AveragedSampleMarginLossLoss(alpha=gamma)
# Alternatives: MSELoss() or nn.CrossEntropyLoss().

print(args.exp)

# Threshold used by --adjusted: target cosines below -1 + eps are corrected.
eps = 0.8
store_name = './log/spherical/' + args.exp + '/dim={}, C={}, N={}, lambda={}, gamma={}, lr={}, adjusted={}, rescaled={}'.format(p, C, N, lamb, gamma, lr, args.adjusted, args.rescaled)
tf_writer = SummaryWriter(log_dir=store_name)

acc_list = []
m_list = []
norm_w_list = []
norm_h_list = []
error_list = []

# W never changes during training, so its margin and squared norm are
# computed once rather than on every iteration.
margin = get_margin(W)
norm_w = norm_of_weight(W)

epochs = 50000
for ep in range(epochs):
    logits = F.linear(F.normalize(H, dim=-1), W)

    if args.adjusted:
        normalized_logits = F.linear(F.normalize(H, dim=-1), F.normalize(W, dim=-1))
        label_one_hot = F.one_hot(labels, logits.size(1)).to(logits.dtype)
        target_logits = normalized_logits * label_one_hot
        # Select target-class cosines below -1 + eps. Non-target entries are
        # 0 here and are not selected as long as eps < 1.
        mask = (target_logits < -1 + eps).to(logits.dtype)
        mask_logits = normalized_logits * mask
        c = -(1 + math.sqrt(1 - (1 - eps) ** 2) / (1 - eps) * mask_logits / torch.sqrt(1 - mask_logits ** 2)) * mask
        logits = logits + c.detach() * logits

    if args.rescaled:
        # Scale each sample's logits by its (detached) feature norm.
        logits = torch.norm(H, dim=1, keepdim=True).detach() * logits

    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Metrics only: no autograd graph is needed. `acc` uses the logits from
    # before this step, while norm_f / error use the updated H.
    with torch.no_grad():
        acc = evaluate(logits, labels)
        norm_f = norm_of_weight(H)
        error = torch.sum((F.normalize(H, dim=-1) - H_target_unit) ** 2).item()

    acc_list.append(acc)
    m_list.append(margin)
    norm_w_list.append(norm_w)
    norm_h_list.append(norm_f)
    error_list.append(error)

    tf_writer.add_scalar('acc', acc, ep)
    tf_writer.add_scalar('margin', margin, ep)
    tf_writer.add_scalar('W', norm_w, ep)
    tf_writer.add_scalar('H', norm_f, ep)
    tf_writer.add_scalar('err', error, ep)
    if ep % 200 == 0:
        print('Iter {}: loss={:.4f}, acc={:.4f}, margin={:.4f}, norm_w={:.4f}, norm_f={:.4f}, error={:.4f}'.format(ep, loss.item(), acc, margin, norm_w, norm_f, error))

# Flush any pending TensorBoard events before the script exits.
tf_writer.close()

acc_list = np.array(acc_list)
m_list = np.array(m_list)
norm_w_list = np.array(norm_w_list)
norm_h_list = np.array(norm_h_list)
error_list = np.array(error_list)

np.save(os.path.join(store_name, 'acc.npy'), acc_list)
np.save(os.path.join(store_name, 'margin.npy'), m_list)
np.save(os.path.join(store_name, 'norm_w.npy'), norm_w_list)
np.save(os.path.join(store_name, 'norm_h.npy'), norm_h_list)
np.save(os.path.join(store_name, 'error.npy'), error_list)